{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.定义所需函数\n",
    "**(1)sigmoid函数**\n",
    "$$ g(z) = \\frac{1}{1+e^{(-z)}} $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(z):\n",
    "    '''Logistic sigmoid g(z) = 1 / (1 + e^(-z)), evaluated element-wise.'''\n",
    "    denominator = 1 + np.exp(-z)\n",
    "    return 1 / denominator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(2)sigmoid函数的导数**\n",
    "$$ g'(z) = \\frac{e^{-z}}{(1+e^{-z})^2} = \\frac{1}{1+e^{-z}} \\cdot \\frac{e^{-z}}{1+e^{-z}} = g(z) \\cdot (1-g(z)) $$\n",
    "$ a^{(l)}_n $ 表示为第 $ l $ 层第 $ n $ 个激活值，即$ a^{(l)}_n = g(z^{(l)}_n) $ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoidDerivative(a):\n",
    "    '''\n",
    "    Derivative of the sigmoid, expressed through the sigmoid's output.\n",
    "\n",
    "    The argument is the activation a = g(z), NOT z itself, so the\n",
    "    derivative is simply a * (1 - a).\n",
    "\n",
    "    Args:\n",
    "        a : array of activation values (the original note says: all\n",
    "            activations of one layer)\n",
    "    Returns:\n",
    "        element-wise product a * (1 - a), same shape as a\n",
    "    '''\n",
    "    complement = 1 - a\n",
    "    # np.multiply multiplies position by position, so the shape is preserved\n",
    "    return np.multiply(a, complement)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(3)初始化权值矩阵** <br/>\n",
    "&emsp;&emsp;初始化 $ \\theta^{l}_{ji} $ <br/>\n",
    "&emsp;&emsp;$ j $ 取值为 $ \\{ 1,2,...,N_{S^{l+1}} \\} $ ， $ i $ 取值为 $ \\{ 0,1,2,...,N_{S^{l}} \\} $（$ i = 0 $ 对应偏置项）<br/>\n",
    "&emsp;&emsp;$ N_{S^{l}} $ 表示第l层神经元个数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def initThetas(hiddenNum, unitNum, inputSize, classNum, epsilon):\n",
    "    '''\n",
    "    Randomly initialise one weight matrix per pair of adjacent layers.\n",
    "\n",
    "    Args:\n",
    "        hiddenNum : number of hidden layers\n",
    "        unitNum : number of units in every hidden layer\n",
    "        inputSize : number of input features\n",
    "        classNum : number of output units\n",
    "        epsilon : half-width of the uniform initialisation interval\n",
    "    Return:\n",
    "        thetas : list of arrays; entry l has shape\n",
    "                 (units in layer l+1, units in layer l + 1)\n",
    "    '''\n",
    "    # Layer sizes from input to output, e.g. [3, 4, 4, 3]\n",
    "    layerSizes = [inputSize] + [unitNum] * hiddenNum + [classNum]\n",
    "    weights = []\n",
    "    for fanIn, fanOut in zip(layerSizes[:-1], layerSizes[1:]):\n",
    "        # The extra column holds the bias weight; rand() is rescaled\n",
    "        # from [0, 1) to [-epsilon, epsilon)\n",
    "        weights.append(np.random.rand(fanOut, fanIn + 1) * 2 * epsilon - epsilon)\n",
    "    return weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(4)计算代价值函数**\n",
    "$$ J(\\theta) = - \\frac{1}{m}[ \\sum_{i=1}^m \\sum_{k=1}^K y_k^{(i)} \\log(h_\\theta(x^{(i)}))_k + (1-y_k^{(i)})\\log(1-h_\\theta(x^{(i)}))_k]+ \\frac{\\lambda}{2m}\\sum_{l=1}^{L-1} \\sum_{i=1}^{S_l} \\sum_{j=1}^{S_{l+1}} (\\theta_{ji}^{(l)})^2 $$<br/>\n",
    "此处的 $ y_k^{(i)} $ 是指第 $ i $ 个样本的真实标签在类别 $ k $ 上的取值（属于第 $ k $ 类时为 1，否则为 0） <br/><br/>\n",
    "而 $ y^{(i)} $ 为向量 $ [y_1^{(i)},y_2^{(i)},...,y_K^{(i)}] $"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def computeCost(thetas, y, theLambda, X=None, a=None):\n",
    "    '''\n",
    "    Regularised cross-entropy cost of the network.\n",
    "\n",
    "    Args:\n",
    "        thetas : sequence of weight matrices\n",
    "        y : vectorised labels, one row per sample\n",
    "        theLambda : L2 regularisation strength\n",
    "        X : samples; only used (via fp) when a is not supplied\n",
    "        a : per-layer activations; computed with fp(thetas, X) if None\n",
    "    Returns:\n",
    "        J : the cost\n",
    "    '''\n",
    "    # Number of samples\n",
    "    sampleCount = y.shape[0]\n",
    "    if a is None:\n",
    "        a = fp(thetas, X)\n",
    "\n",
    "    # Only the output layer a[-1] is compared with the labels.\n",
    "    # Per the original note, y is already vectorised (exactly one 1 per row).\n",
    "    output = a[-1]\n",
    "    positiveTerm = np.multiply(y.T, np.log(output))\n",
    "    negativeTerm = np.multiply((1 - y).T, np.log(1 - output))\n",
    "    error = -np.sum(positiveTerm + negativeTerm)\n",
    "\n",
    "    # L2 penalty over all weights except the bias column (column 0)\n",
    "    penalties = []\n",
    "    for theta in thetas:\n",
    "        weights = theta[:, 1:]\n",
    "        penalties.append(np.sum(np.multiply(weights, weights)))\n",
    "    reg = np.sum(penalties)\n",
    "    return (1.0 / sampleCount) * error + (1.0 / (2 * sampleCount)) * theLambda * reg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(5)标签向量化、参数展开成向量、参数恢复成矩阵**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adjustLabels(y):\n",
    "    '''Vectorise a column of class labels.\n",
    "\n",
    "    Args:\n",
    "        y : labels. Either a single column of class values (e.g. the\n",
    "            digit labels of the handwritten-digit example) or an\n",
    "            already vectorised label matrix with more than one column.\n",
    "            A 1-D array is treated as a single column.\n",
    "    Returns:\n",
    "        yAdjusted : float64 ndarray.\n",
    "            * more than 2 classes: one-hot matrix (samples x classes),\n",
    "              columns ordered by ascending class value\n",
    "            * 1 or 2 classes: a single column, 0 for the smallest class\n",
    "              and 1 otherwise\n",
    "            * input with more than one column: returned as float64,\n",
    "              otherwise unchanged\n",
    "    '''\n",
    "    y = np.asarray(y)\n",
    "    if y.ndim == 1:\n",
    "        y = y.reshape(-1, 1)\n",
    "    if y.shape[1] != 1:\n",
    "        # Already vectorised: previously this path hit an unbound variable\n",
    "        return y.astype(np.float64)\n",
    "\n",
    "    labels = y.ravel()\n",
    "    classes = np.unique(labels)  # sorted distinct class values\n",
    "    classNum = len(classes)\n",
    "    if classNum > 2:\n",
    "        # Multi-class: set the column of each sample's class to 1.\n",
    "        # Looking the label up in the sorted class list (instead of using\n",
    "        # label - min) also works when class values are not contiguous.\n",
    "        yAdjusted = np.zeros((y.shape[0], classNum), np.float64)\n",
    "        columns = np.searchsorted(classes, labels)\n",
    "        yAdjusted[np.arange(y.shape[0]), columns] = 1.0\n",
    "    else:\n",
    "        # Binary: everything that is not the smallest class becomes 1\n",
    "        yAdjusted = np.zeros((y.shape[0], 1), np.float64)\n",
    "        if classNum:\n",
    "            yAdjusted[labels != classes[0], 0] = 1.0\n",
    "    return yAdjusted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unroll(matrixes):\n",
    "    '''\n",
    "    Flatten a sequence of matrices into one 1-D vector (row-major,\n",
    "    matrices in the given order).\n",
    "\n",
    "    Args:\n",
    "        matrixes : iterable of arrays (ndarray or np.matrix)\n",
    "    Return:\n",
    "        vec : new 1-D ndarray holding a copy of all elements; an empty\n",
    "              input gives an empty float64 array\n",
    "    '''\n",
    "    # The empty float array keeps the previous dtype behaviour (result is\n",
    "    # at least float64) and makes the empty-input case return an ndarray.\n",
    "    pieces = [np.zeros(0)]\n",
    "    # np.asarray drops the np.matrix subclass so reshape(-1) is really 1-D\n",
    "    pieces.extend(np.asarray(matrix).reshape(-1) for matrix in matrixes)\n",
    "    # A single concatenate instead of re-copying the vector on every step\n",
    "    return np.concatenate(pieces)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def roll(vector, shapes):\n",
    "    '''\n",
    "    Cut a flat vector back into a list of 2-D matrices.\n",
    "\n",
    "    Args:\n",
    "        vector : 1-D array holding all matrix elements back to back\n",
    "        shapes : list of (rows, cols) shapes, in the order the matrices\n",
    "                 appear in vector\n",
    "    Returns:\n",
    "        matrixes : list of matrices reshaped from consecutive slices of vector\n",
    "    '''\n",
    "    restored = []\n",
    "    offset = 0\n",
    "    for shape in shapes:\n",
    "        size = shape[0] * shape[1]\n",
    "        restored.append(vector[offset:offset + size].reshape(shape))\n",
    "        offset += size\n",
    "    return restored"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(6)前向传播过程** <br/>\n",
    "用来计算 $ a^l_i $ ，每一层的激活函数值 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fp(theta, X):\n",
    "    '''\n",
    "    Forward propagation.\n",
    "\n",
    "    Args:\n",
    "        theta : list of weight matrices, one per pair of adjacent layers\n",
    "        X : input samples, one sample per row\n",
    "    Returns:\n",
    "        a : list with one activation array per layer; every column is one\n",
    "            sample. All layers except the output layer carry an extra\n",
    "            leading row of ones (the bias unit).\n",
    "    '''\n",
    "    # Use the weights that were passed in. The previous body read a\n",
    "    # notebook-level variable named thetas and ignored this argument.\n",
    "    layerNum = len(theta) + 1\n",
    "    # Fixed-length list, filled layer by layer below\n",
    "    a = [None] * layerNum\n",
    "\n",
    "    for l in range(layerNum):\n",
    "        if l == 0:\n",
    "            # Input layer: activations are the inputs, stored column-wise\n",
    "            a[l] = X.T\n",
    "        else:\n",
    "            # np.dot is a matrix product for ndarray and np.matrix alike\n",
    "            z = np.dot(theta[l - 1], a[l - 1])\n",
    "            a[l] = sigmoid(z)\n",
    "\n",
    "        # Every layer except the output layer gets a bias row of ones\n",
    "        if l != layerNum - 1:\n",
    "            a[l] = np.concatenate((np.ones((1, a[l].shape[1])), a[l]))\n",
    "    return a"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(7)反向传播过程** <br/>\n",
    "用来计算权值梯度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bp(thetas, a, y, theLambda):\n",
    "    '''Back-propagation: gradient of the regularised cost w.r.t. every weight matrix.\n",
    "\n",
    "    Args:\n",
    "        thetas : list of weight matrices\n",
    "        a : per-layer activations as returned by fp (columns are samples;\n",
    "            non-output layers include the bias row)\n",
    "        y : vectorised labels, one row per sample\n",
    "        theLambda : L2 regularisation strength (bias column is not regularised)\n",
    "    Returns:\n",
    "        D : list of gradient arrays, D[l] has the shape of thetas[l]\n",
    "    '''\n",
    "    m = y.shape[0]\n",
    "    layerNum = len(thetas) + 1\n",
    "    # d[l] is the error term of layer l; the input layer (index 0) has none\n",
    "    d = [None] * layerNum\n",
    "\n",
    "    for l in range(layerNum - 1, 0, -1):  # walk the layers backwards\n",
    "        if l == layerNum - 1:\n",
    "            # Output layer error\n",
    "            d[l] = a[l] - y.T\n",
    "        else:\n",
    "            # Hidden layer: propagate through the weights without the bias\n",
    "            # column and drop the bias row of the activations.\n",
    "            # np.dot is used so plain ndarrays work as well as np.matrix.\n",
    "            d[l] = np.multiply(np.dot(thetas[l][:, 1:].T, d[l + 1]),\n",
    "                               sigmoidDerivative(a[l][1:, :]))\n",
    "\n",
    "    D = []\n",
    "    for l, theta in enumerate(thetas):\n",
    "        # Accumulated gradient over all samples, as a plain ndarray\n",
    "        delta = np.asarray(np.dot(d[l + 1], a[l].T))\n",
    "        grad = np.zeros(theta.shape)\n",
    "        # Bias column: no regularisation\n",
    "        grad[:, 0] = (1.0 / m) * delta[:, 0]\n",
    "        # Remaining weights: add the L2 term\n",
    "        grad[:, 1:] = (1.0 / m) * (delta[:, 1:] + theLambda * theta[:, 1:])\n",
    "        D.append(grad)\n",
    "    return D"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(8)更新权值**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def updateThetas(m, thetas, D, alpha, theLambda):\n",
    "    \"\"\"Apply one gradient-descent step to every weight matrix.\n",
    "\n",
    "    The list passed as thetas is modified in place (each entry is replaced\n",
    "    by a new array) and the same list object is returned.\n",
    "\n",
    "    Args:\n",
    "        m : number of samples (not used by this function)\n",
    "        thetas : list of weight matrices\n",
    "        D : list of gradients, one per weight matrix\n",
    "        alpha : learning rate\n",
    "        theLambda : regularisation parameter (not used by this function)\n",
    "    Returns:\n",
    "        thetas : the updated list of weight matrices\n",
    "    \"\"\"\n",
    "    for layer, weights in enumerate(thetas):\n",
    "        step = alpha * D[layer]\n",
    "        thetas[layer] = weights - step\n",
    "    return thetas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(9)梯度下降**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradientDescent(thetas, X, y, alpha, theLambda):\n",
    "    \"\"\"Run a single gradient-descent iteration.\n",
    "\n",
    "    Args:\n",
    "        thetas : list of weight matrices (updated by updateThetas)\n",
    "        X : samples, one per row\n",
    "        y : vectorised labels\n",
    "        alpha : learning rate\n",
    "        theLambda : regularisation parameter\n",
    "    Returns:\n",
    "        J : cost evaluated BEFORE this iteration's weight update\n",
    "            (np.inf if the cost came out as NaN)\n",
    "        Thetas : the updated weight matrices\n",
    "    \"\"\"\n",
    "    # Sample count; the feature count is not needed here\n",
    "    sampleCount, _ = X.shape\n",
    "    # Forward pass: activations of every layer\n",
    "    activations = fp(thetas, X)\n",
    "    # Backward pass: gradients\n",
    "    gradients = bp(thetas, activations, y, theLambda)\n",
    "    # Cost of the current weights, reusing the activations\n",
    "    cost = computeCost(thetas, y, theLambda, a=activations)\n",
    "    # Weight update\n",
    "    thetas = updateThetas(sampleCount, thetas, gradients, alpha, theLambda)\n",
    "    # A NaN cost is reported as infinity so callers can test for it\n",
    "    return (np.inf if np.isnan(cost) else cost), thetas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(10)梯度校验**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradientCheck(thetas,X,y,theLambda):\n",
    "    \"\"\"Compare the back-propagation gradient with a numerical estimate.\n",
    "\n",
    "    Each weight is perturbed by +/- epsilon in turn and the gradient is\n",
    "    estimated with a central difference of the cost.\n",
    "\n",
    "    NOTE(review): the estimate is only meaningful if computeCost (and the\n",
    "    forward pass it calls) evaluates the network with the weight list it\n",
    "    is given.\n",
    "\n",
    "    Args:\n",
    "        thetas : list of weight matrices\n",
    "        X : samples\n",
    "        y : vectorised labels\n",
    "        theLambda : regularisation parameter\n",
    "    Returns:\n",
    "        checked : True if the mean absolute difference between the two\n",
    "                  gradients is below 1e-5\n",
    "    \"\"\"\n",
    "    # Analytical gradient from forward + backward pass\n",
    "    a = fp(thetas, X)\n",
    "    D = bp(thetas, a, y, theLambda)\n",
    "    DVec = unroll(D)\n",
    "\n",
    "    # Numerical gradient. This epsilon is the finite-difference step; it is\n",
    "    # unrelated to the epsilon used for weight initialisation.\n",
    "    epsilon = 1e-4\n",
    "    gradApprox = np.zeros(DVec.shape)\n",
    "    ThetaVec = unroll(thetas)\n",
    "    shapes = [theta.shape for theta in thetas]\n",
    "    for i, item in enumerate(ThetaVec):\n",
    "        ThetaVec[i] = item - epsilon\n",
    "        JMinus = computeCost(roll(ThetaVec, shapes), y, theLambda, X=X)\n",
    "        ThetaVec[i] = item + epsilon\n",
    "        JPlus = computeCost(roll(ThetaVec, shapes), y, theLambda, X=X)\n",
    "        # Restore the weight so later estimates are not taken at a point\n",
    "        # shifted by all the previous perturbations\n",
    "        ThetaVec[i] = item\n",
    "        gradApprox[i] = (JPlus - JMinus) / (2 * epsilon)\n",
    "\n",
    "    # Mean ABSOLUTE difference: a signed mean lets errors cancel and lets\n",
    "    # any negative difference pass the threshold\n",
    "    diff = np.average(np.abs(gradApprox - DVec))\n",
    "    return bool(diff < 1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(11)网络训练**<br/>\n",
    "训练时进行迭代，计算最优参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(X, y,checkFlag=False, thetas=None, hiddenNum=0, unitNum=5, epsilon=1, alpha=1, theLambda=0, precision=0.0001, maxIters=50):\n",
    "    \"\"\"Train the network with batch gradient descent.\n",
    "\n",
    "    Args:\n",
    "        X : training samples, one per row\n",
    "        y : labels (vectorised with adjustLabels before training)\n",
    "        checkFlag : run the (slow) gradient check first; default False\n",
    "        thetas : initial weight matrices; if None they are drawn randomly\n",
    "                 with initThetas. NOTE: a list passed here is updated in\n",
    "                 place by updateThetas.\n",
    "        hiddenNum : number of hidden layers (only used when thetas is None)\n",
    "        unitNum : units per hidden layer (only used when thetas is None)\n",
    "        epsilon : weights are initialised in [-epsilon, epsilon]\n",
    "        alpha : learning rate\n",
    "        theLambda : regularisation parameter\n",
    "        precision : stop when the cost changes by less than this between\n",
    "                    two iterations\n",
    "        maxIters : maximum number of iterations\n",
    "    Returns:\n",
    "        dict with keys\n",
    "            'error' : last cost returned by gradientDescent (np.inf if no\n",
    "                      iteration ran; None if the gradient check failed)\n",
    "            'thetas' : trained weight matrices (None if the check failed)\n",
    "            'iters' : zero-based index of the last iteration executed\n",
    "    \"\"\"\n",
    "    # Number of samples and features\n",
    "    m, n = X.shape\n",
    "    # Vectorise the labels, e.g. one-hot for multi-class problems\n",
    "    y = adjustLabels(y)\n",
    "    classNum = y.shape[1]\n",
    "    # Initialise the weights if none were supplied\n",
    "    if thetas is None:\n",
    "        thetas = initThetas(\n",
    "            inputSize=n,\n",
    "            hiddenNum=hiddenNum,\n",
    "            unitNum=unitNum,\n",
    "            classNum=classNum,\n",
    "            epsilon=epsilon\n",
    "        )\n",
    "\n",
    "    # Gradient check; messages are only printed when a check really runs\n",
    "    checked = True\n",
    "    if checkFlag:\n",
    "        print('Doing Gradient Checking....')\n",
    "        checked = gradientCheck(thetas, X, y, theLambda)\n",
    "        print('Gradient Checked.')\n",
    "\n",
    "    if not checked:\n",
    "        print('Error: Gradient Checking Failed!!!')\n",
    "        # Same keys as the success result (was 'Thetas' with a capital T)\n",
    "        return {\n",
    "            'error': None,\n",
    "            'thetas': None,\n",
    "            'iters': 0\n",
    "        }\n",
    "\n",
    "    # Defaults so the result is well defined even when maxIters is 0\n",
    "    error = np.inf\n",
    "    last_error = np.inf\n",
    "    i = 0\n",
    "    for i in range(maxIters):\n",
    "        error, thetas = gradientDescent(\n",
    "            thetas, X, y, alpha=alpha, theLambda=theLambda)\n",
    "        # Stop on convergence, or when the cost blew up (NaN is mapped to inf)\n",
    "        if abs(error - last_error) < precision or error == np.inf:\n",
    "            break\n",
    "        last_error = error\n",
    "\n",
    "    return {\n",
    "        'error': error,\n",
    "        'thetas': thetas,\n",
    "        'iters': i\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(12)预测函数**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(X, thetas):\n",
    "    \"\"\"Run the network on X and return the output layer.\n",
    "\n",
    "    Args:\n",
    "        X: samples, one per row\n",
    "        thetas: trained weight matrices\n",
    "    Return:\n",
    "        activations of the last layer as produced by fp\n",
    "    \"\"\"\n",
    "    return fp(thetas, X)[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.手写数字数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io import loadmat\n",
    "from matplotlib import pyplot\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the handwritten-digit dataset; the cells below read its 'X' (samples) and 'y' (labels) entries\n",
    "data = loadmat('data/handwritten_digits.mat')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(400,)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of a single sample: a flat feature vector (reshaped to 20x20 for display further down)\n",
    "data['X'][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAAD8CAYAAACLgjpEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEm1JREFUeJzt3X+QXWV9x/HPZze7CQkBEiABEn4NBKYBJSATShnbIBUhUiIW2zDWpooFFWbqTNsR2xnjWKfSaRHHgqBgBmQUtJZAGCIkQ5miHRSSGH4JNJEGWUITIJgQfiW7++0fe5bZ7p4nee49d/feu3m/ZjJ77znfPee52fDhnHuffb6OCAFAmY5mDwBA6yIgACQREACSCAgASQQEgCQCAkASAQEgiYAAkERAAEia0OwBlOnumBT7dUxt9jCAceut/te1q/9t762uJQNiv46pOvOARc0eBjBuPbzj7qy6SrcYts+z/aztjbavKtk/0fYPi/2/sH1MlfMBGFt1B4TtTknXSzpf0lxJl9ieO6zsUkmvRcTxkq6V9E/1ng/A2KtyBTFf0saIeC4idkm6Q9Lw+4JFkm4tHv9Y0jm293rfA6A1VAmIWZJeGPK8p9hWWhMRvZK2Szq4wjkBjKEqb1KWXQkMX1wip2ag0L5M0mWSNKljSoVhAWiUKlcQPZKOHPJ8tqTNqRrbEyQdKGlb2cEi4jsRcXpEnN7t/SoMC0CjVAmIRyXNsX2s7W5JiyWtGFazQtKS4vHFkv4jWMIKaBt132JERK/tKyXdL6lT0rKIeMr2VyStiYgVkr4r6TbbGzVw5bC4EYMGMDYqTZSKiJWSVg7b9qUhj9+W9LEq5wDQPPwuBoAkAgJAEgEBIImAAJBEQABIIiAAJBEQAJIICABJBASAJAICQBIBASCpJRetBerSn/+LwtHbm3nM/uxjurs7u1Yd7bGwGlcQAJIICABJBASAJAICQBIBASCJgACQVKWz1pG2H7T9tO2nbP9VSc0C29ttry/+fKnsWABaU5V5EL2S/joi1tmeKmmt7dUR8athdT+NiAsqnAdAk9R9BRERL0XEuuLx65Ke1sjOWgDaWEPegyi6dp8q6Rclu8+0/Zjtn9g+qRHnAzA2Kk+1tr2/pH+X9PmI2DFs9zpJR0fETtsLJd0laU7iOLTeGyJ27a6hdld2bcfkyfmDaIXpwDVMn/aU/NfWd/SMvGP21XD+Tf+bXRtvv51/3M7mfZZQ6cy2uzQQDt+PiDuH74+IHRGxs3i8UlKX7UPKjkXrPaD1VPkUwxronPV0RHw9UXNYUSfb84vzvVrvOQGMrSq3GGdJ+oSkJ2yvL7b9naSjJCkibtRAP87P2u6V9JakxfTmBNpHld6cP5O0x5vUiLhO0nX1ngNAczGTEkASAQEgiYAAkERAAEgiIAAkERAAkljVegz1v5U/vXbHR07Nrt06P38MJ3xtQ35xDdO9R2tadk3TZmpYgXrjFZ1ZdZ859aHsYy770Yeya4/95lPZtbVMN2/0z4ErCABJBASAJAICQBIBASCJgACQREAASCIgACQREACSCAgAScykbIDY3ZtV1znj0OxjvvNn27Jrf3zybdm1n3tkRH+jpAPu+mV2bcd+k7Jra7K7hsV7Dz8yu/bDv/NkVt3fTv919jFvPmlndq26urNLa1rgds9rONWMKwgASZUDwvYm208UrfXWlOy37W/a3mj7cdunVT0ngLHRqFuMsyPilcS+8zXQC2OOpDMk3VB8BdDixuIWY5Gk78WAn0s6yPbhY3BeABU1IiBC0irba4vuWMPNkvTCkOc9oocn0BYacYtxVkRstj1D0mrbz0TE0F+iL3tbdcQvuNN6D2g9la8gImJz8XWrpOWShi9f0iNp6OdPsyVtLjkOrfeAFlO1N+cU21MHH0s6V9LwD5hXSPrz4tOM35W0PSJeqnJeAGOj6i3GTEnLi/abEyT9
ICLus/0Z6d32eyslLZS0UdKbkj5Z8ZwAxkilgIiI5ySdUrL9xiGPQ9IVVc4DoDmYat0Akbm46/Yz8j+8uf6kb2XXTu3In478+lH5d5VTM6eQS5JqeNso+vIXl/WUydm1T1++f3bttw59MKvuvjcPzj7mtHvyx9q/Y0d27ahNY885d9PODKDlERAAkggIAEkEBIAkAgJAEgEBIImAAJBEQABIIiAAJBEQAJKYap3SP2LJiqSOgw7MqnvjE9uzj3lyd/706Zu3n5xdO2PtO9m1HZMmZtfWIt7OH0Mcf1R27V+e9Z/ZtYd25v3Tv3DdH2cfc/adj2fXeuLo/N02GlcQAJIICABJBASAJAICQBIBASCJgACQREAASKo7IGyfWPTjHPyzw/bnh9UssL19SM2Xqg8ZwFipe6JURDwraZ4k2e6U9KIG+mIM99OIuKDe8wBonkbdYpwj6dcR8XyDjgegBTRqqvViSbcn9p1p+zENdNP6m4h4qqyo1VrvRW/+is67Tzs+q+7m996496LCC735Kz/f8u2F2bWH/dfa7FrXsJpyLStVdxwzO7u2Z2n+cT89bV127Te2vS+rbtZXyzpHNkDHKB23wSpfQdjulnShpH8r2b1O0tERcYqkf5V0V+o4tN4DWk8jbjHOl7QuIrYM3xEROyJiZ/F4paQu24c04JwAxkAjAuISJW4vbB/moi+f7fnF+V5twDkBjIFK70HYnizpg5IuH7JtaF/OiyV91navpLckLS5a8QFoA1V7c74p6eBh24b25bxO0nVVzgGgeZhJCSCJgACQREAASCIgACQREACSWNU6IXbnT7V+/vy8FYpndu7KPuaadw7Lrj38od9m1/bvyh9Dfw0re0dv/ircGz93UnbtL953TXZtp/KnL992z9lZdcesfTj7mOrqzq91/ljd1bz/TLmCAJBEQABIIiAAJBEQAJIICABJBASAJAICQBIBASCJgACQREAASNqnplr373wju3bHxadn117z0Vvzzp99ROnQzh3ZtRs+cUB27dHTT8uujQn504HfOKwru/bgE17Jru1y/v/DXu/vy66deFLe9PSdHzsj+5iTXs2fnj/xN9uya7W1hlUaG7xaNlcQAJKyAsL2MttbbT85ZNt026ttbyi+Tkt875KiZoPtJY0aOIDRl3sFcYuk84Ztu0rSAxExR9IDxfP/x/Z0SUslnSFpvqSlqSAB0HqyAiIiHpI0/KZpkaTBm+9bJX2k5Fs/JGl1RGyLiNckrdbIoAHQoqq8BzEzIl6SpOLrjJKaWZJeGPK8p9gGoA2M9qcYZW+plq5C0mq9OQFUu4LYYvtwSSq+bi2p6ZF05JDnszXQxHcEenMCradKQKyQNPipxBJJd5fU3C/pXNvTijcnzy22AWgDuR9z3i7pYUkn2u6xfamkqyV90PYGDbTfu7qoPd32zZIUEdsk/YOkR4s/Xym2AWgDWe9BRMQliV3nlNSukfTpIc+XSVpW1+gANNU+NdU6+vKn4r41Pf/u67iuvKmwu2toW3xY55vZtasv/pfs2p9++Jj8QdRgVtdr2bWd5e9Tl9rSlz9BvS/yf2b3n3ZTVt2z78mfxv5G5K9q/cUnL8qunbU4fyVyT8pbYT0XU60BJBEQAJIICABJBASAJAICQBIBASCJgACQREAASCIgACQREACS9qmp1or8Kb67Dso/7LETOrPqemqYNjzZ+WOd2pF3fkn6wORN2bXdzl8huZb/02zvz39tfaVLipQ7qCP/7/fl/rx/+p988FPZx5x4wDvZtdPvrGHNk878n2+jcQUBIImAAJBEQABIIiAAJBEQAJIICABJew2IRNu9f7b9jO3HbS+3XfqhoO1Ntp+wvd72mkYOHMDoy7mCuEUju2GtlnRyRLxX0n9L+uIevv/siJgXEfntsgG0hL0GRFnbvYhYFRGDvc5/roF+FwDGmUa8B/EpST9J7AtJq2yvLTpnAWgjlaZa2/57Sb2Svp8oOSsiNtueIWm1
7WeKK5KyY4166z1PyH+5B/xP/rTdB9/OW/n49ybmtwS54bXTsmtveuT92bX7bcpfeblvYv6U6Pcs2JBd+/Wj78qu3dWfP9X6ay8vyK5d+WDeHe/cbzyffcy+V/JWN5ck1zB92t1d2bWNVvcVhO0lki6Q9PGI8l9yiIjNxdetkpZLmp86Hq33gNZTV0DYPk/SFyRdGBGlDRxsT7E9dfCxBtruPVlWC6A15XzMWdZ27zpJUzVw27De9o1F7RG2VxbfOlPSz2w/JukRSfdGxH2j8ioAjIq93pQn2u59N1G7WdLC4vFzkk6pNDoATcVMSgBJBASAJAICQBIBASCJgACQREAASNqnVrXumDw5u3bavb/Krv3H3iVZdTtn50+vPeTx/BWST3jwl9m16u/Lr63B2huSk2RHOPDY/L+H+944Lrv28aXzsmuPX7U2q65/0sTsY3bsNym7tl1wBQEgiYAAkERAAEgiIAAkERAAkggIAEkEBIAkAgJAEgEBIGmfmkmpjvwFUNWfv2DrAfc8llU3tXzpzlJ2/li9f/4iv4nlQ0t11HDcL/zBvdm1uyN/QeCrH/ij7NoTVuXPKM2e9VjLv5lxiCsIAEn1tt77su0Xi/Uo19temPje82w/a3uj7asaOXAAo6/e1nuSdG3RUm9eRKwcvtN2p6TrJZ0vaa6kS2zPrTJYAGOrrtZ7meZL2hgRz0XELkl3SFpUx3EANEmV9yCuLLp7L7M9rWT/LEkvDHneU2wD0CbqDYgbJB0naZ6klyRdU1JT9vZv8i1025fZXmN7za54q85hAWikugIiIrZERF9E9Eu6SeUt9XokHTnk+WxJm/dwTFrvAS2m3tZ7hw95epHKW+o9KmmO7WNtd0taLGlFPecD0Bx7nShVtN5bIOkQ2z2SlkpaYHueBm4ZNkm6vKg9QtLNEbEwInptXynpfkmdkpZFxFOj8ioAjIpRa71XPF8pacRHoADaw7411boWNUyxdebCpq5h+vaoTfHtq2EMNZg+YWd27WR3Zdce9ZP8adnurOGOeR+fQp2LqdYAkggIAEkEBIAkAgJAEgEBIImAAJBEQABIIiAAJBEQAJIICABJTLUeSy0wvbeW1bJj9+7s2q9+++PZtUvzF8vWMc/XsJhZV/4UbuThCgJAEgEBIImAAJBEQABIIiAAJBEQAJJy1qRcJukCSVsj4uRi2w8lnViUHCTptxExr+R7N0l6XVKfpN6IOL1B4wYwBnLmQdwi6TpJ3xvcEBF/OvjY9jWStu/h+8+OiFfqHSCA5slZtPYh28eU7fPArJs/kfSBxg4LQCuo+h7E+yVtiYgNif0haZXttbYvq3guAGOs6lTrSyTdvof9Z0XEZtszJK22/UzRDHiEIkAuk6RJHTXMxUVtapnu3dubXTrr+nV1DGbv3N2dX1zLqtbIUvffqO0Jkj4q6YepmqJPhiJiq6TlKm/RN1hL6z2gxVSJ3D+U9ExE9JTttD3F9tTBx5LOVXmLPgAtaq8BUbTee1jSibZ7bF9a7FqsYbcXto+wPdhJa6akn9l+TNIjku6NiPsaN3QAo63e1nuKiL8o2fZu672IeE7SKRXHB6CJeFcHQBIBASCJgACQREAASCIgACQREACSWNUaDeFJE5s9BIwCriAAJBEQAJIICABJBASAJAICQBIBASCJgACQREAASCIgACQREACSHBHNHsMItl+W9PywzYdIGo8NeMbr65LG72sbD6/r6Ig4dG9FLRkQZWyvGY+t+8br65LG72sbr6+rDLcYAJIICABJ7RQQ32n2AEbJeH1d0vh9beP1dY3QNu9BABh77XQFAWCMtUVA2D7P9rO2N9q+qtnjaRTbm2w/YXu97TXNHk8VtpfZ3mr7ySHbpttebXtD8XVaM8dYj8Tr+rLtF4uf23rbC5s5xtHU8gFhu1PS9ZLOlzRX0iW25zZ3VA11dkTMGwcfm90i6bxh266S9EBEzJH0QPG83dyika9Lkq4tfm7zImJlyf5x
oeUDQgMdwTdGxHMRsUvSHZIWNXlMGCYiHpK0bdjmRZJuLR7fKukjYzqoBki8rn1GOwTELEkvDHneU2wbD0LSKttrbV/W7MGMgpkR8ZIkFV9nNHk8jXSl7ceLW5C2u3XK1Q4B4ZJt4+Wjl7Mi4jQN3D5dYfv3mz0gZLlB0nGS5kl6SdI1zR3O6GmHgOiRdOSQ57MlbW7SWBqq6IauiNgqabkGbqfGky22D5ek4uvWJo+nISJiS0T0RUS/pJs0/n5u72qHgHhU0hzbx9rulrRY0oomj6ky21NsTx18LOlcSU/u+bvazgpJS4rHSyTd3cSxNMxg6BUu0vj7ub2r5RvnRESv7Ssl3S+pU9KyiHiqycNqhJmSltuWBn4OP4iI+5o7pPrZvl3SAkmH2O6RtFTS1ZJ+ZPtSSb+R9LHmjbA+ide1wPY8DdzqbpJ0edMGOMqYSQkgqR1uMQA0CQEBIImAAJBEQABIIiAAJBEQAJIICABJBASApP8Dsc72mIIP3XAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x23ebea48cc0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pyplot.imshow(data['X'][2200].reshape(20,20).T) # transpose needed because of how the .mat file stores the data\n",
    "print(data['y'][2200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert to np.matrix: always 2-D, and `*` between matrices is a matrix product\n",
    "X = np.mat(data['X'])\n",
    "y = np.mat(data['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-computed initial weights; presumably Theta1 maps input->hidden and Theta2 hidden->output (TODO confirm)\n",
    "thetas = loadmat('data/init_weights.mat')\n",
    "thetas = [thetas['Theta1'], thetas['Theta2']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing Gradient Checking....\n",
      "Gradient Checked.\n"
     ]
    }
   ],
   "source": [
    "# thetas is supplied, so train() skips random initialisation and hiddenNum/unitNum are not used\n",
    "res = train(X,y,checkFlag=False, hiddenNum=1,unitNum=25,thetas=thetas,maxIters=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(499, 0.19417985808360613)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 'iters' is the zero-based index of the last iteration run; 'error' is the cost returned by that iteration\n",
    "res['iters'], res['error']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readable_predict(idx, X, thetas, labels=None):\n",
    "    '''Print the predicted and the true label of sample idx and show its image.\n",
    "\n",
    "    Args:\n",
    "        idx : row index of the sample in X\n",
    "        X : sample matrix, one 400-value row per 20x20 image\n",
    "        thetas : trained weight matrices\n",
    "        labels : label matrix aligned with X; if None, the notebook-level\n",
    "                 variable y is used (previous behaviour)\n",
    "    '''\n",
    "    if labels is None:\n",
    "        # Backward-compatible fallback to the global label matrix\n",
    "        labels = y\n",
    "    # The network's output index starts at 0, hence the +1\n",
    "    print('predict:', (np.argmax(predict(X[idx], thetas))+1))\n",
    "    # In the true labels the digit 0 is stored as 10\n",
    "    print('real tag:', labels[idx].ravel())\n",
    "    pyplot.imshow(X[idx].reshape(20,20).T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "predict: 10\n",
      "real tag: [[10]]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAAD8CAYAAACLgjpEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEmxJREFUeJzt3X+QXWV9x/HPZzcJv7IhCSEBkghpiGigEhEDDLYDoikgY6QKDToaFQx1pBWmv6jOiKPTqY4DmbY4UMU0KBqxtdHMGAIZ+gOZQSFhwo9IMDFNyLJpws8FTchms9/+sSfMdvc+7LP33t1z7+X9msncc8/57jnP2d355Jxzn30eR4QAoJK2shsAoHEREACSCAgASQQEgCQCAkASAQEgiYAAkERAAEgiIAAkjSu7AZVMaDsyjmrvKLsZQMvaf+hV9fS95uHqGjIgjmrv0HlTPlx2M4CW9dBLP86qq+kWw/bFtp+2vc32jRW2H2H77mL7L22fUsvxAIytqgPCdrukb0q6RNJ8SVfZnj+o7GpJL0XEqZKWS/p6tccDMPZquYJYKGlbRGyPiB5JP5S0eFDNYkl3Fsv/Juki28Pe9wBoDLUExExJuwa87yzWVayJiF5J3ZKOq+GYAMZQLQ8pK10JDB5cIqemv9BeJmmZJB3ZNrGGZgGol1quIDolzR7wfpakrlSN7XGSjpX0YqWdRcS3IuLsiDh7QtuRNTQLQL3UEhCPSJpne47tCZKWSFozqGaNpKXF8kck/UcwhBXQNKq+xYiIXtvXSbpXUrukFRGx2fZXJG2IiDWSviPpe7a3qf/KYUk9Gg1gbNTUUSoi1kpaO2jdlwYsvybpilqOAaA8/C0GgCQCAkASAQEgiYAAkERAAEgiIAAkERAAkggIAEkEBIAkAgJAEgEBIKkhB61FExrBH+lGz8H8/R4cQW17e3apjzwis/DNPQAaVxAAkggIAEkEBIAkAgJAEgEBIImAAJBUy8xas23/p+2nbG+2/fkKNRfY7ra9qfj3pUr7AtCYaukH0SvpLyLiUdsdkjbaXh8RvxpU9/OIuKyG4wAoSdVXEBGxOyIeLZZflfSUhs6sBaCJ1eUZRDFr9zsl/bLC5vNsP2b7Htun1+N4AMZGzV2tbU+U9GNJ10fEK4M2Pyrp5Ij4re1LJf1E0rzEfph6byyMpEv0aweya9smdWTX7n/33Oza/z1nfHZtx878czt+3fasujjQk71PtbfeM/+azsj2ePWHw/cj4t8Hb4+IVyLit8XyWknjbU+rtC+m3gMaTy2fYlj9M2c9FRG3JGpOKOpke2FxvBeqPSaAsVXLLcb5kj4u6Qnbm4p1X5D0FkmKiNvVPx/nZ233StovaQlzcwLNo5a5OR+U9IZ/CxsRt0q6tdpjAChX6z1VAVA3BASAJAICQBIBASCJgACQREAASGJU61YwglGi3ZHfjX3nsrdl105/77PZtR+ZuS679oqOLdm1mw5Mzq7960mfyao76XtPZe8z9r+WXZs9qrZU6sjaXEEASCIgACQREACSCAgASQQEgCQCAkASAQEgiYAAkERAAEiiJ2WDin3782vfPie79tW/y9/v6rd/I7v2+LbR6e33cl/+AGTzJ7yUXbvyhuVZdZ9Z9PHsfbbdfVx27bT1eYPmSlIczO8pW+9el1xBAEiqOSBs77D9RDG13oYK2237H21vs/247bNqPSaAsVGvW4wLI+L5xLZL1D8XxjxJ50i6rXgF0ODG4hZjsaTvRr9fSJps+8QxOC6AGtUjIELSfbY3FrNjDTZT0q4B7zvFHJ5AU6jHLcb5EdFle7qk9ba3RMQDA7ZXeqw65NE0U+8BjafmK4iI6Cpe90paLWnhoJJOSbMHvJ8lqavCfph6D2gwtc7NeYztjsPLkhZJenJQ2RpJnyg+zThXUndE7K7luADGRq23GDMkrS6m3xwn6QcRsc72n0qvT7+3
VtKlkrZJ2ifpUzUeE8AYqSkgImK7pDMrrL99wHJI+lwtxwFQDrpaj6E40JNde+A987NrJ36xM7t27dzV2bUHRzDN8qvRl137nZfyu8FcPunR7NoZ7fldkk8a15tVd8+Z/5K9z52nj8+u/cLWa7Jr23+1I7tWE/LbkIOu1gCSCAgASQQEgCQCAkASAQEgiYAAkERAAEgiIAAkERAAkggIAEl0ta6D3BGoX7nsHdn7/OiXf5Zdu3TS1uzarkP5XaKX/uoT2bUvbpyeXXvyPfkja6/6s3fl155zR3btj14ePCpBZZcfuzF7n6eNP5Bd2ze+Pbu2LfL7vNd7bHGuIAAkERAAkggIAEkEBIAkAgJAEgEBIImAAJBUdUDYPq2Yj/Pwv1dsXz+o5gLb3QNqvlR7kwGMlao7SkXE05IWSJLtdknPqn9ejMF+HhGXVXscAOWp1y3GRZJ+ExE767Q/AA2gXl2tl0haldh2nu3H1D+b1l9GxOZKRQ039d4IuiT3nTE3q27O9Vuy9zmS7tNPH8zP+Y9+/4bs2lNvy8/7yS8/nl2rtvz2zrv+6OzaZYs/n1173BP7sup23zwpe5//MGtddm2zqPkKwvYESR+U9K8VNj8q6eSIOFPSP0n6SWo/TL0HNJ563GJcIunRiNgzeENEvBIRvy2W10oab3taHY4JYAzUIyCuUuL2wvYJLubls72wON4LdTgmgDFQ0zMI20dLer+kawesGzgv50ckfdZ2r6T9kpYUU/EBaAK1zs25T9Jxg9YNnJfzVkm31nIMAOWhJyWAJAICQBIBASCJgACQREAASGJU64S+fXldcSXpN1cek1X3vdn5I1W/3Jf/afBH7/qr7Nq5t+R39x7J59E+anR6v0Zvb3bttG8/nF3bPm9OVt3co5/P3ufRbeOza5vlv+YmaSaAMhAQAJIICABJBASAJAICQBIBASCJgACQREAASCIgACQREACS3lxdrXsOZpceOnd+du0nF/1XVt1E53fFvejJq7JrT/3nZ7JrI/JH69a48n89+rpfza5tO2Nedu2BW/K60l9/3Mbsfd71Sl73bUka130gu7YYtbEUXEEASMoKCNsrbO+1/eSAdVNtr7e9tXidkvjapUXNVttL69VwAKMv9wpipaSLB627UdL9ETFP0v3F+//H9lRJN0k6R9JCSTelggRA48kKiIh4QNKLg1YvlnRnsXynpA9V+NI/krQ+Il6MiJckrdfQoAHQoGp5BjEjInZLUvE6vULNTEm7BrzvLNYBaAKj/Zi60uPXiuOQNNzcnABquoLYY/tESSpe91ao6ZQ0e8D7WeqfxHcI5uYEGk8tAbFG0uFPJZZK+mmFmnslLbI9pXg4uahYB6AJ5H7MuUrSQ5JOs91p+2pJX5P0fttb1T/93teK2rNt3yFJEfGipK9KeqT495ViHYAmkPUMIiJS3fouqlC7QdI1A96vkLSiqtYBKFX5fWnHUPT0ZNc+syj/Ocj1Uzdl1XX35Y/QvO/eGdm1k7p+mV3bPnVydu2IjGD0abW3Z5c+9+l3ZddOu3LX8EWFdW+rdEc81Le635q9z7u++oHs2ik7nsqu1YQRjJZdZ3S1BpBEQABIIiAAJBEQAJIICABJBASAJAICQBIBASCJgACQREAASGr+rtaH8kdpbptxfHbt1LMq/fV6ZQczR4o+ui2/i/GhI7JLpb5D2aWx/7X82gP5Iy+3nTJ7+KLCtmtOyK5dveSW7Nqpbfnfh4/9z+Ksuq7lp2bvc/J9m7Nry+w+PRJcQQBIIiAAJBEQAJIICABJBASAJAICQNKwAZGYdu8btrfYftz2atsVhymyvcP2E7Y32d5Qz4YDGH05VxArNXQ2rPWSzoiId0j6taS/fYOvvzAiFkTE2dU1EUBZhg2IStPuRcR9EXF4EMJfqH++CwAtph7PID4t6Z7EtpB0n+2NxcxZAJpITV2tbX9RUq+k7ydKzo+ILtvTJa23vaW4Iqm0r+qm3juU3722d0b+
iM5//9bUKQ31u8yu1l29+f2nD3ZUnKGwovb5+SMv/25u/vfg+TPyfz3OXfx4du13Z96VXfv1596TXXvPj87Lrn3L6ryu9B2dTw5fVPBRrTcjXNVXELaXSrpM0sciouJvc0R0Fa97Ja2WtDC1P6beAxpPVQFh+2JJfyPpgxGxL1FzjO2Ow8vqn3YvP44BlC7nY85K0+7dKqlD/bcNm2zfXtSeZHtt8aUzJD1o+zFJD0v6WUSsG5WzADAqhr3JTEy7951EbZekS4vl7ZLOrKl1AEpFT0oASQQEgCQCAkASAQEgiYAAkERAAEhq/lGt25xd2t69P7t21QvnZNcuP+m/s+q6+3qHLyp89cofZNdu/MCc7Np3Hr0zu/aKiS9k1/Yqv8v79oP5P7MHlp+bXTtr1cPZteroyCprxe7TI8EVBIAkAgJAEgEBIImAAJBEQABIIiAAJBEQAJIICABJBASApObvSTku/xSic3d27YY7FmTX3v7ne7LqPnVs/oh7p0/szq599xEPZtc+1nNCdu0Nu/N7k75n0q+za1c+e3527eStv8uubZt4THat2vm/MQffJQBJ1U6992XbzxbjUW6yfWniay+2/bTtbbZvrGfDAYy+aqfek6TlxZR6CyJi7eCNttslfVPSJZLmS7rK9vxaGgtgbFU19V6mhZK2RcT2iOiR9ENJi6vYD4CS1PIM4rpidu8VtqdU2D5T0q4B7zuLdQCaRLUBcZukuZIWSNot6eYKNZX+6D85n5ztZbY32N7Q0/dalc0CUE9VBURE7ImIQxHRJ+nbqjylXqek2QPez5LU9Qb7ZOo9oMFUO/XeiQPeXq7KU+o9Imme7Tm2J0haImlNNccDUI5hexkVU+9dIGma7U5JN0m6wPYC9d8y7JB0bVF7kqQ7IuLSiOi1fZ2keyW1S1oREZtH5SwAjIpRm3qveL9W0pCPQAE0B0cknxuW5tjxx8d5Uz5cbiNG8H3p+f1Tsup2ve+I7H0eOjL/+BOfyb9T7NiVP7jsxK353b27T5+cXXvs5peza/1M8rHVUBPG59e+yT300o/VffC5YUcPpqs1gCQCAkASAQEgiYAAkERAAEgiIAAkERAAkggIAEkEBIAkAgJAUvOPaj1aPGwv1NeN37g1q+73Hs7v5jxqRnBeHp//63Hszmfz9zuSEaXpPl0qriAAJBEQAJIICABJBASAJAICQBIBASApZ0zKFZIuk7Q3Is4o1t0t6bSiZLKklyNiyGy3tndIelXSIUm9EXF2ndoNYAzkfNC9UtKtkr57eEVE/MnhZds3S3qjsckujIjnq20ggPLkDFr7gO1TKm2zbUlXSnpvfZsFoBHU+gziDyTtiYhUV8KQdJ/tjbaX1XgsAGOs1q7WV0la9Qbbz4+ILtvTJa23vaWYDHiIIkCWSdKRbRNrbNbY8hETym5C6Uyn/ZZU9RWE7XGS/ljS3amaYp4MRcReSatVeYq+w7VMvQc0mFpuMd4naUtEdFbaaPsY2x2HlyUtUuUp+gA0qGEDoph67yFJp9nutH11sWmJBt1e2D7J9uGZtGZIetD2Y5IelvSziFhXv6YDGG3VTr2niPhkhXWvT70XEdslnVlj+wCUiJ6UAJIICABJBASAJAICQBIBASCJgACQREAASCIgACQREACSCAgASQQEgCQCAkASAQEgiYAAkERAAEgiIAAkERAAkhwRZbdhCNvPSdo5aPU0Sa04AU+rnpfUuufWCud1ckQcP1xRQwZEJbY3tOLUfa16XlLrnlurnlcl3GIASCIgACQ1U0B8q+wGjJJWPS+pdc+tVc9riKZ5BgFg7DXTFQSAMdYUAWH7YttP295m+8ay21MvtnfYfsL2Jtsbym5PLWyvsL3X9pMD1k21vd721uJ1SpltrEbivL5s+9ni57bJ9qVltnE0NXxA2G6X9E1Jl0iaL+kq2/PLbVVdXRgRC1rgY7OVki4etO5GSfdHxDxJ9xfvm81KDT0vSVpe/NwWRMTaCttb
QsMHhPpnBN8WEdsjokfSDyUtLrlNGCQiHpD04qDViyXdWSzfKelDY9qoOkic15tGMwTETEm7BrzvLNa1gpB0n+2NtpeV3ZhRMCMidktS8Tq95PbU03W2Hy9uQZru1ilXMwSEK6xrlY9ezo+Is9R/+/Q5239YdoOQ5TZJcyUtkLRb0s3lNmf0NENAdEqaPeD9LEldJbWlrorZ0BUReyWtVv/tVCvZY/tESSpe95bcnrqIiD0RcSgi+iR9W633c3tdMwTEI5Lm2Z5je4KkJZLWlNymmtk+xnbH4WVJiyQ9+cZf1XTWSFpaLC+V9NMS21I3h0OvcLla7+f2unFlN2A4EdFr+zpJ90pql7QiIjaX3Kx6mCFptW2p/+fwg4hYV26Tqmd7laQLJE2z3SnpJklfk/Qj21dLekbSFeW1sDqJ87rA9gL13+rukHRtaQ0cZfSkBJDUDLcYAEpCQABIIiAAJBEQAJIICABJBASAJAICQBIBASDp/wC2NgGJccB+JgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x23ebefa3358>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Predict sample 235 with the trained weights and show it together with its true label\n",
    "readable_predict(235, X, res['thetas'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
